from diffusers import AutoencoderKL
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
import os

class VAE():
    """
    Thin wrapper around a pretrained ``AutoencoderKL``.

    It turns images into scaled latents and latents back into images.
    Input pixels are divided by 255 and normalized with mean 0.5 / std 0.5
    per channel. Optionally the lower half of the image is zeroed (before
    normalization) using a precomputed mask.
    """

    def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
        """
        Load the autoencoder and prepare the preprocessing helpers.

        :param model_path: Path passed to ``AutoencoderKL.from_pretrained``.
        :param resized_img: Side length (pixels) that image files are resized
            to; also the side length of the half mask.
        :param use_float16: Whether to convert the model to float16.
        """
        self.model_path = model_path
        self.vae = AutoencoderKL.from_pretrained(self.model_path)

        # Prefer the GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.vae.to(self.device)

        self._use_float16 = bool(use_float16)
        if self._use_float16:
            self.vae = self.vae.half()

        # Latents are multiplied by this factor after encoding and divided by
        # it again before decoding.
        self.scaling_factor = self.vae.config.scaling_factor
        self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self._resized_img = resized_img
        self._mask_tensor = self.get_mask_tensor()

    def get_mask_tensor(self):
        """
        Build the half mask used by :meth:`preprocess_img`.

        :return: A ``[resized_img, resized_img]`` float tensor holding 1 in
            the top ``resized_img // 2`` rows and 0 everywhere else.
        """
        size = self._resized_img
        mask_tensor = torch.zeros((size, size))
        mask_tensor[:size // 2, :] = 1
        return mask_tensor

    def preprocess_img(self, img_name, half_mask=False):
        """
        Preprocess a single image for the VAE.

        :param img_name: Either an image file path (``str``) or an already
            loaded image array in OpenCV BGR channel order. A file is resized
            to ``resized_img`` x ``resized_img``; an array is NOT resized, so
            with ``half_mask=True`` it must already have that size for the
            mask to broadcast.
        :param half_mask: Whether to zero the lower half of the image before
            normalization.
        :return: A float tensor of shape ``[1, C, H, W]`` (RGB) on the VAE's
            device.
        :raises FileNotFoundError: If ``img_name`` is a path that OpenCV
            cannot read (missing file or not a decodable image).
        """
        if isinstance(img_name, str):
            img = cv2.imread(img_name)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise FileNotFoundError(
                    f"Could not read image (missing or unreadable file): {img_name!r}"
                )
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (self._resized_img, self._resized_img),
                             interpolation=cv2.INTER_LANCZOS4)
        else:
            img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)

        # Scale to [0, 1] and go from HWC to CHW. Transposing the single image
        # directly avoids an unconditional squeeze, which would also drop a
        # height or width of 1.
        x = np.transpose(np.asarray(img) / 255., (2, 0, 1))
        x = torch.FloatTensor(x)
        if half_mask:
            # The [H, W] mask broadcasts over the channel dimension.
            x = x * (self._mask_tensor > 0.5)
        x = self.transform(x)

        x = x.unsqueeze(0)  # add the batch dimension
        x = x.to(self.vae.device)

        return x

    def encode_latents(self, image):
        """
        Encode an image tensor into scaled latents.

        The latents are drawn with ``latent_dist.sample()``, so repeated calls
        on the same image are not guaranteed to return identical values.

        :param image: The image tensor to encode.
        :return: The sampled latents multiplied by ``scaling_factor``.
        """
        # The encoder runs without building an autograd graph.
        with torch.no_grad():
            init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
        init_latents = self.scaling_factor * init_latent_dist.sample()
        return init_latents

    def decode_latents(self, latents):
        """
        Decode latents into displayable images.

        :param latents: The (scaled) latents to decode.
        :return: A contiguous ``uint8`` NumPy array of shape ``[N, H, W, C]``
            with the channel axis reversed (RGB to BGR).
        """
        # The result leaves torch as a detached NumPy array, so no autograd
        # graph is needed for the decoder pass.
        with torch.no_grad():
            image = self.just_decode_latents(latents)
            # Map the decoder output from [-1, 1] to [0, 1].
            image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        # Reversing the last axis yields a negative-stride view; return a
        # contiguous copy so the array can be handed to APIs that require a
        # plain memory layout.
        return np.ascontiguousarray(image[..., ::-1])  # RGB to BGR

    def just_decode_latents(self, latents):
        """
        Decode latents and return the raw decoder output tensor.

        Unlike :meth:`decode_latents`, no range mapping, clamping or NumPy
        conversion is applied, and gradient tracking is not disabled here.

        :param latents: The (scaled) latents to decode.
        :return: The decoder's ``sample`` tensor.
        """
        latents = (1 / self.scaling_factor) * latents
        image = self.vae.decode(latents.to(self.vae.dtype)).sample
        return image

    def get_latents_for_unet(self, img):
        """
        Prepare latent variables for a U-Net model from a single image.

        The same image provides both the half-masked input and the unmasked
        reference input.

        :param img: The image to process (file path or BGR array).
        :return: Masked and reference latents concatenated along dimension 1.
        """
        return self.get_train_latents_for_unet(img, img)

    def get_train_latents_for_unet(self, hal_face, ref_face):
        """
        Prepare latent variables for a U-Net model from two images.

        :param hal_face: Image that is half-masked before encoding.
        :param ref_face: Image that is encoded without masking.
        :return: Masked latents followed by reference latents, concatenated
            along dimension 1.
        """
        masked_image = self.preprocess_img(hal_face, half_mask=True)
        masked_latents = self.encode_latents(masked_image)
        ref_image = self.preprocess_img(ref_face, half_mask=False)
        ref_latents = self.encode_latents(ref_image)
        return torch.cat([masked_latents, ref_latents], dim=1)
    


if __name__ == "__main__":
    # Manual smoke test: prints the array shapes produced by the first
    # preprocessing steps (load, BGR->RGB, resize, scale, transpose).
    # No model is loaded.
    import sys

    # An optional first command-line argument overrides the default test image.
    img_path = sys.argv[1] if len(sys.argv) > 1 else "d:/11.png"
    img = cv2.imread(img_path)
    # cv2.imread returns None instead of raising when the file cannot be read.
    if img is None:
        raise SystemExit(f"Could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_LANCZOS4)
    window = [img]
    x = np.asarray(window) / 255.
    print("0x shape:", x.shape)
    x = np.transpose(x, (3, 0, 1, 2))
    print("1x shape:", x.shape)

    # Example usage (kept disabled): encode every cropped PNG frame in a
    # directory into U-Net input latents.
    #
    # vae_mode_path = "./models/sd-vae-ft-mse/"
    # vae = VAE(model_path = vae_mode_path,use_float16=False)
    #
    # crop_imgs_path = "./results/sun001_crop/"
    # latents_out_path = "./results/latents/"
    # if not os.path.exists(latents_out_path):
    #     os.mkdir(latents_out_path)
    #
    # files = os.listdir(crop_imgs_path)
    # files.sort()
    # files = [file for file in files if file.split(".")[-1] == "png"]
    #
    # for file in files:
    #     index = file.split(".")[0]
    #     img_path = crop_imgs_path + file
    #     latents = vae.get_latents_for_unet(img_path)
    #     print(img_path,"latents",latents.size())
    #     #torch.save(latents,os.path.join(latents_out_path,index+".pt"))
        

    